package com.jstarcraft.ai.jsat.clustering.evaluation;

import static java.lang.Math.exp;
import static java.lang.Math.log;

import java.util.List;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;

/**
 * Adjusted Rand Index (ARI) is a measure to evaluate a clustering based on the
 * true class labels for the data set. The ARI normally returns a value in [-1,
 * 1], where 0 indicates the clustering appears random, 1 indicates the
 * clusters perfectly match the class labels, and negative values indicate a
 * clustering that is worse than random. To match the {@link ClusterEvaluation}
 * interface, the value returned by evaluate will be 1.0 - Adjusted Rand Index,
 * so the best value becomes 0.0 and the worst value becomes 2.0. <br>
 * <b>NOTE:</b> Because the ARI needs to know the true class labels, only
 * {@link #evaluate(int[], com.jstarcraft.ai.jsat.DataSet) } will work, since it
 * provides the data set as an argument. The dataset given must be an instance
 * of {@link ClassificationDataSet}
 * 
 * @author Edward Raff
 */
public class AdjustedRandIndex implements ClusterEvaluation {

    /**
     * Computes 1.0 - ARI between the given cluster designations and the class
     * labels of the data set. Points with a negative designation are treated as
     * noise and ignored. Every remaining point contributes its data set weight
     * (rather than a count of 1) to the contingency table.
     * 
     * @param designations the cluster assigned to each data point, indexed like
     *                     the data set; negative values mark noise points
     * @param dataSet      the data set that was clustered, which must be a
     *                     {@link ClassificationDataSet}
     * @return 1.0 - ARI, so that 0.0 is a perfect match. If both partitions are
     *         identical and trivial (the ARI denominator is zero) 0.0 is
     *         returned. If the total weight of the non-noise points leaves no
     *         pairs to compare (for example fewer than two unit-weight points)
     *         the result is {@link Double#NaN}.
     * @throws IllegalArgumentException if the data set is not a
     *                                  {@link ClassificationDataSet}
     */
    @Override
    public double evaluate(int[] designations, DataSet dataSet) {
        if (!(dataSet instanceof ClassificationDataSet))
            throw new IllegalArgumentException("Adjusted Rand Index can only be calculated for classification data sets");
        ClassificationDataSet cds = (ClassificationDataSet) dataSet;

        // the number of clusters is one more than the largest cluster id seen
        int clusters = 0;
        for (int clusterID : designations)
            clusters = Math.max(clusterID + 1, clusters);

        // weighted contingency table (cluster x class) and its marginal sums
        double[] truthSums = new double[cds.getClassSize()];
        double[] clusterSums = new double[clusters];
        double[][] table = new double[clusterSums.length][truthSums.length];
        double n = 0.0;
        for (int i = 0; i < designations.length; i++) {
            int cluster = designations[i];
            if (cluster < 0)
                continue;// noisy point
            int label = cds.getDataPointCategory(i);
            double weight = cds.getWeight(i);
            table[cluster][label] += weight;
            truthSums[label] += weight;
            clusterSums[cluster] += weight;
            n += weight;
        }

        /*
         * Adjusted Rand Index involves many (n choose 2) = 1/2 (n-1) n
         */

        double sumAllTable = 0.0;
        double addCTerm = 0.0;// clustering marginal term
        for (int i = 0; i < table.length; i++) {
            addCTerm += choose2(clusterSums[i]);
            for (int j = 0; j < table[i].length; j++)
                sumAllTable += choose2(table[i][j]);
        }

        double addLTerm = 0.0;// label marginal term
        for (int j = 0; j < truthSums.length; j++)
            addLTerm += choose2(truthSums[j]);

        // Expected index under random labeling. Computed directly instead of via
        // exp(log(..)): the log form turns negative intermediate sums into NaN
        // and adds rounding error that makes the degenerate case below
        // unpredictable.
        double totalPairs = choose2(n);
        double expectedIndex = addCTerm * addLTerm / totalPairs;
        double denominator = addCTerm / 2 + addLTerm / 2 - expectedIndex;

        // A zero denominator means both partitions are the same trivial partition
        // (e.g. one cluster and one class), which is a perfect match. When there
        // are no pairs at all, expectedIndex is NaN, this test is false and NaN
        // propagates to the caller.
        if (denominator == 0.0)
            return 0.0;
        return 1.0 - (sumAllTable - expectedIndex) / denominator;
    }

    /**
     * Computes "x choose 2" generalised to real valued (weighted) counts.
     * 
     * @param x the (possibly fractional) count
     * @return x * (x - 1) / 2
     */
    private static double choose2(double x) {
        return x * (x - 1) / 2;
    }

    /**
     * Converts a score produced by evaluate back into the plain Adjusted Rand
     * Index.
     * 
     * @param evaluate_score a value returned by one of the evaluate methods
     * @return 1 - evaluate_score
     */
    @Override
    public double naturalScore(double evaluate_score) {
        // returns values in the range of [1, -1], with 1=best, and -1=worst
        return -evaluate_score + 1;
    }

    /**
     * Not supported, since the true class labels are not available from the
     * clustered points alone.
     * 
     * @param dataSets the clusters, each given as a list of data points
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public double evaluate(List<List<DataPoint>> dataSets) {
        throw new UnsupportedOperationException("Adjusted Rand Index requires the true data set" + " labels, call evaluate(int[] designations, DataSet dataSet)" + " instead");
    }

    /**
     * Returns a new evaluator instance.
     * 
     * @return a new {@code AdjustedRandIndex}
     */
    @Override
    public ClusterEvaluation clone() {
        return new AdjustedRandIndex();
    }
}
